package nsalt.mem

import chisel3._
import chisel3.util._

import nsalt._
import nsalt.bus._

/** Write-port payload for a translation lookaside table.
  *
  * Carries a write enable, the target set index, a per-way write mask and the packed entry word.
  *
  * @param IndexBits width of the set index
  * @param Ways      number of ways (one mask bit per way)
  * @param tlbLen    width of one packed table entry
  */
class TransLookasideTableWriteBundle (val IndexBits: Int, val Ways: Int, val tlbLen: Int) extends Bundle with Constants with VirtMemSchema {
  // Declaration order below defines the Bundle's element order; keep it stable.
  val valid   = Output(Bool())
  val index   = Output(UInt(IndexBits.W))
  val waymask = Output(UInt(Ways.W))
  val data    = Output(UInt(tlbLen.W))

  /** Drives every field of this bundle.
    *
    * The entry fields are concatenated with `Cat`, so `virtPageNum` lands in the most-significant
    * bits and `entryAddr` in the least-significant bits of `data`. If the concatenation is wider
    * or narrower than `tlbLen`, the connection resizes it.
    * NOTE(review): confirm the field widths sum to tlbLen.
    */
  def apply(valid: UInt, index: UInt, waymask: UInt, virtPageNum: UInt, asid: UInt, mask: UInt, flag: UInt, physPageNum: UInt, entryAddr: UInt) = {
    val packedEntry = Cat(virtPageNum, asid, mask, flag, physPageNum, entryAddr)
    this.valid   := valid
    this.index   := index
    this.waymask := waymask
    this.data    := packedEntry
  }
}

/** Set-associative storage of `Sets` sets, each holding `WAYS` entries of `tlbLen` bits.
  *
  * - Read: `io.group` combinationally returns every way of the set addressed by `io.indexR`.
  * - Write: when `io.write.valid` is high, `io.write.data` is written into each way whose bit is
  *   set in `io.write.waymask`, at set `io.write.index`.
  * - After reset, the table zeroes all ways of every set, one set per cycle, over `Sets` cycles.
  *   `io.ready` is low during that sweep, and external writes are ignored while it runs.
  */
class TransLookasideTable(implicit val conf: TransLookasideConf) extends Module with TransLookasideConst {

  val io = IO(new Bundle {
    val group  = Output(Vec(WAYS, UInt(tlbLen.W)))
    val write  = Flipped(new TransLookasideTableWriteBundle(IndexBits = IndexBits, Ways = WAYS, tlbLen = tlbLen))
    val indexR = Input(UInt(IndexBits.W))
    val ready  = Output(Bool())
  })

  // One row per set; each row holds all ways side by side.
  val table = Mem(Sets, Vec(WAYS, UInt(tlbLen.W)))

  // Asynchronous read of a whole set.
  io.group := table(io.indexR)

  // Post-reset clearing sweep: high out of reset, counts through every set, drops on wrap.
  val clearing = RegInit(true.B)
  val (clearIdx, clearWrap) = Counter(clearing, Sets)
  when (clearWrap) { clearing := false.B }

  // The table accepts traffic only once the sweep has finished.
  io.ready := !clearing

  // The single write port is shared; the clearing sweep overrides the external request.
  val portIdx  = Mux(clearing, clearIdx, io.write.index)
  val portMask = Mux(clearing, Fill(WAYS, "b1".U), io.write.waymask)
  val portWord = Mux(clearing, 0.U, io.write.data)

  // The same word is broadcast to every way; the mask selects which ways are actually written.
  val portRow = VecInit(Seq.fill(WAYS)(portWord))

  when (clearing || io.write.valid) {
    table.write(portIdx, portRow, portMask.asBools)
  }
}